#!/usr/bin/env python2

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import pickle
import platform
import re

from mullvad import osx_net_services
from mullvad import proc


# Pickled snapshot of the Windows/macOS DNS settings, written by save() and
# consumed (then deleted) by restore().
# NOTE(review): this path is relative to the current working directory, so
# save() and restore() must be run from the same directory to find it.
saved_dns_file = 'dnsconfig.pickle'
# Linux only: set() moves the original /etc/resolv.conf here and restore()
# moves it back.
saved_resolv_link = '/etc/resolv.conf.pre-mullvad'

# Patterns used to pick DNS server addresses out of tool output. They are raw
# strings so that the backslash sequences reach the regex engine verbatim
# instead of relying on Python passing unknown string escapes through (which
# is deprecated since Python 3.6 and a SyntaxWarning since 3.12).
ipv4_re = r'((\d{1,3}\.){3}\d{1,3})'
# Group 1 is the address itself; an optional '%<digits>' suffix is matched as
# a separate group so callers using group(1) drop it. Only lowercase hex
# digits are matched.
ipv6_re = r'(([0-9a-f]{0,4}:)+[0-9a-f]{1,4})(%\d+)?'

# Address pattern per protocol family name as passed to netsh; 'ip' maps to
# the same pattern as 'ipv4'.
ip_regex = {
    'ip': ipv4_re,
    'ipv4': ipv4_re,
    'ipv6': ipv6_re,
}


# API ========================================================================


def save():
    """Snapshot the host's current DNS settings so restore() can undo set().

    Windows: the IPv4 and IPv6 configuration of every interface is pickled
    to ``saved_dns_file``. macOS: the server list of every network service
    is pickled there instead. An existing snapshot file is never
    overwritten, so only the first call after a restore() records anything.

    Linux: nothing to do here -- set() itself moves the original
    /etc/resolv.conf aside before writing a new one.
    """
    system = platform.system()

    if system == 'Windows':
        def collect():
            return {
                'ipv4': _win_get_dns_config('ipv4'),
                'ipv6': _win_get_dns_config('ipv6'),
            }
    elif system == 'Darwin':
        collect = _osx_get_config
    else:
        # Linux (and any other system): no snapshot file is used.
        return

    if os.path.exists(saved_dns_file):
        # Keep the snapshot that is already there.
        return

    dns_state = collect()
    with open(saved_dns_file, 'wb') as snapshot:
        pickle.dump(dns_state, snapshot)


def restore():
    """Put back the DNS configuration that existed before set() was used.

    Windows/macOS: if the snapshot written by save() exists, it is loaded,
    re-applied, and then deleted. Linux: if set() left a backup at
    ``saved_resolv_link``, it is moved back to /etc/resolv.conf. When there
    is nothing saved, this does nothing.
    """
    system = platform.system()

    if system == 'Linux':
        if os.path.exists(saved_resolv_link):
            os.rename(saved_resolv_link, '/etc/resolv.conf')
        return

    if system not in ('Windows', 'Darwin'):
        return
    if not os.path.exists(saved_dns_file):
        return

    # NOTE(review): pickle.load can execute arbitrary code, and this file is
    # looked up relative to the current directory. Only safe as long as no
    # untrusted party can write to that location.
    with open(saved_dns_file, 'rb') as snapshot:
        dns_state = pickle.load(snapshot)

    if system == 'Windows':
        # One entry per IP family, each in _win_get_dns_config() format.
        for family, family_config in dns_state.items():
            _win_set_dns_config(family, family_config)
    else:
        _osx_set_config(dns_state)

    os.remove(saved_dns_file)


def set(servers):
    """Make the host resolve names through the given DNS ``servers``.

    ``servers`` may mix IPv4 and IPv6 addresses:

    * Windows: each family is configured separately on the interfaces that
      support it; entries matching neither address pattern are dropped.
    * macOS: the whole list is applied to all active network services.
    * Linux: the current /etc/resolv.conf is moved to ``saved_resolv_link``
      (unless such a backup already exists) and replaced by a file holding
      one ``nameserver`` line per entry.

    Note that this function shadows the ``set`` builtin inside this module.
    """
    def matching(pattern):
        return [addr for addr in servers if re.match(pattern, addr)]

    v4_servers = matching(ipv4_re)
    v6_servers = matching(ipv6_re)

    system = platform.system()
    if system == 'Windows':
        _win_set_dns_servers('ipv4', v4_servers)
        _win_set_dns_servers('ipv6', v6_servers)
    elif system == 'Darwin':
        _osx_set_servers(servers)
    elif system == 'Linux':
        # An existing backup is left alone, so repeated calls keep the first
        # original resolv.conf available for restore().
        if not os.path.exists(saved_resolv_link):
            os.rename('/etc/resolv.conf', saved_resolv_link)
        with open('/etc/resolv.conf', 'w') as resolv_conf:
            resolv_conf.write(
                ''.join('nameserver {}\n'.format(s) for s in servers))


# Windows ====================================================================

def _win_get_dns_config(ipv):
    """Snapshot the DNS settings of every interface of one IP family.

    Args:
        ipv: protocol family name passed to netsh; also used as the key into
            ``ip_regex`` to find addresses in the output.

    Returns:
        A dict mapping each interface index from _win_get_interfaces() to
        ``{'source': u'dhcp' or u'static', 'servers': [address, ...]}``.
        ``source`` is u'dhcp' if any output line mentions DHCP. Every output
        line that contains an address of the family contributes its first
        such address, in output order.
    """
    confs = {}
    for ifc in _win_get_interfaces(ipv):
        netsh = u'netsh interface {0} show dns name={1}'.format(ipv, ifc)
        output_lines = proc.run_assert_ok(netsh.split()).splitlines()

        source = u'static'
        servers = []
        for line in output_lines:
            if u'DHCP' in line:
                source = u'dhcp'
            found = re.search(ip_regex[ipv], line)
            if found is not None:
                servers.append(found.group(1))

        confs[ifc] = {'source': source, 'servers': servers}

    return confs


def _win_set_dns_config(ipv, config):
    """Configure the hosts DNS settings.

    Set the DNS configuration for the given IP version to match that which
    is given in the 'config' argument which is a dictionary of the same
    structure as the one returned by _win_get_dns_config().
    """
    commands = []
    for ifc, conf in config.items():
        if conf['source'] == u'dhcp':
            cmd = u'netsh interface {} set dns name={} source=dhcp validate=no'
            commands.append(cmd.format(ipv, ifc))
        elif conf['source'] == u'static':
            cmd = (u'netsh interface {} {} dns name={} '
                   'source=static addr={} validate=no')
            if len(conf['servers']) == 0:
                commands.append(cmd.format(ipv, 'set', ifc, 'none'))
            else:
                commands.append(cmd.format(ipv, 'set', ifc,
                                           conf['servers'][0]))

                cmd = u'netsh interface {} {} dns name={} addr={} validate=no'
                for server in conf['servers'][1:]:
                    commands.append(cmd.format(ipv, 'add', ifc, server))

    for command in commands:
        proc.run_assert_ok(command.split())


def _win_set_dns_servers(ipv, servers):
    """Give every interface of family ``ipv`` the same static DNS servers.

    All interfaces returned by _win_get_interfaces(ipv) are switched to a
    static configuration using ``servers`` (an empty list clears them).
    """
    per_interface = {
        ifc: {'source': u'static', 'servers': servers}
        for ifc in _win_get_interfaces(ipv)
    }
    _win_set_dns_config(ipv, per_interface)


def _win_get_interfaces(ipv):
    """Return the indexes of the interfaces netsh lists for family ``ipv``.

    A row is kept only if its first whitespace-separated token is an
    integer, which skips headers, separators and blank lines. Rows that
    mention 'isatap' or 'Teredo' are skipped, as is index 1 (presumably the
    loopback pseudo-interface -- confirm).
    """
    netsh = u'netsh interface {} show interfaces'.format(ipv)
    listing = proc.run_assert_ok(netsh.split())

    indexes = []
    for row in listing.splitlines():
        if 'isatap' in row or 'Teredo' in row:
            continue
        fields = row.split()
        if not fields:
            continue
        try:
            index = int(fields[0])
        except ValueError:
            continue
        if index != 1:
            indexes.append(index)
    return indexes


# OSX ========================================================================


def _osx_get_dns_servers(service):
    """Return the DNS servers ``networksetup`` reports for one service.

    If the output mentions 'DNS' at all (presumably networksetup's "no DNS
    servers set" message -- confirm), ``['empty']`` is returned. That is the
    same placeholder _osx_set_servers() passes for an empty list.
    Otherwise every output line is taken as a server, except lines
    containing 'currently disabled'.
    """
    report = proc.run_assert_ok(['networksetup', '-getdnsservers', service])
    if 'DNS' not in report:
        return list(filter(lambda entry: 'currently disabled' not in entry,
                           report.splitlines()))
    return ['empty']


def _osx_set_servers(servers):
    """Point every active network service at the given DNS servers.

    An empty ``servers`` list is sent as the single argument 'empty'.
    """
    for svc in osx_net_services.get_services():
        argv = ['networksetup', '-setdnsservers', svc]
        argv.extend(servers if len(servers) > 0 else ['empty'])
        proc.run_assert_ok(argv)


def _osx_get_config():
    """Map every active network service to its DNS server list.

    The lists come from _osx_get_dns_servers(); the result is the snapshot
    format that _osx_set_config() accepts.
    """
    return {
        svc: _osx_get_dns_servers(svc)
        for svc in osx_net_services.get_services()
    }


def _osx_set_config(config):
    """Set the hosts DNS configuration to match the given configuration."""
    for service, servers in config.items():
        command = ['networksetup', '-setdnsservers', service] + servers
        proc.run_assert_ok(command)


# ============================================================================

def test():
    """Round-trip save() / set() / restore() on the live host and verify it.

    WARNING: this temporarily replaces the machine's real DNS configuration
    (presumably requiring administrator/root rights). It then asserts that
    restore() brings back exactly the configuration that was there before.
    """
    system = platform.system()
    print('System:', system)

    if system == 'Windows':
        print('Running DNS configuration test for Windows...')
        servers = ['8.8.4.4', '8.8.8.8',
                   '2001:4860:4860::8844', '2001:4860:4860::8888']
        # Family name -> the addresses set() should configure for it
        # ('ip' uses the IPv4 pattern, as in ip_regex).
        expected = {
            'ip': [s for s in servers if re.match(ipv4_re, s)],
            'ipv6': [s for s in servers if re.match(ipv6_re, s)],
        }

        for fam in ('ip', 'ipv6'):
            print('Testing', fam)
            old_state = _win_get_dns_config(fam)
            save()
            assert os.path.exists(saved_dns_file)
            set(servers)
            for conf in _win_get_dns_config(fam).values():
                assert conf['servers'] == expected[fam]
            restore()
            assert not os.path.exists(saved_dns_file)
            assert _win_get_dns_config(fam) == old_state
            print('Test successful\n\n')

    elif system == 'Darwin':
        print('Running DNS configuration test for OSX...')
        old_state = _osx_get_config()
        save()
        assert os.path.exists(saved_dns_file)
        set(['8.8.4.4', '8.8.8.8'])
        restore()
        # On macOS the snapshot lives in saved_dns_file (saved_resolv_link is
        # only used on Linux); restore() must have consumed it.
        assert not os.path.exists(saved_dns_file)
        assert _osx_get_config() == old_state
        print('Test successful\n\n')

    elif system == 'Linux':
        print('Running DNS configuration test for Linux...')
        with open('/etc/resolv.conf', 'r') as f:
            old_state = f.read()
        set(['8.8.4.4', '8.8.8.8'])
        assert os.path.exists(saved_resolv_link)
        with open('/etc/resolv.conf', 'r') as f:
            assert f.read() == 'nameserver 8.8.4.4\nnameserver 8.8.8.8\n'
        restore()
        assert not os.path.exists(saved_resolv_link)
        with open('/etc/resolv.conf', 'r') as f:
            assert f.read() == old_state
        print('Test successful\n\n')

    else:
        # set()/restore() do nothing on other systems, so there is nothing
        # meaningful to check.
        print('No DNS configuration test for this platform')


if __name__ == '__main__':
    # Running the module directly exercises the live self-test above.
    test()
